#include "utility.h"

using namespace std;


/**
 * Worker-thread entry point (signature required by _beginthreadex).
 *
 * For every first SNP in the inclusive range [start_snp, end_snp] and every
 * second SNP with a larger index, builds one genotype contingency table per
 * stratum ("table") and scores the pair with the postCorrection_N routine
 * selected by numofCov. Pairs whose score exceeds thresholdRecord are appended
 * to this thread's Partial_InterPairs / Partial_InterMeasure vectors.
 *
 * Table layout (per stratum m, NumOfCell ints): cell index = g1 * 3 + g2,
 * where g1 is the genotype index used for the first SNP and g2 the one used
 * for the second SNP. Only the four cells with g1, g2 in {0, 1} are counted
 * directly (popcount of AND-ed words); the remaining five are derived from
 * the per-SNP marginal counts in pMarginalDistrSNP_Y.
 *
 * @param arglist  Pointer to a para_data describing the slice to process.
 * @return 0 if the slice is empty/invalid or numofCov is out of the supported
 *         range, 1 after the slice has been processed.
 */
unsigned __stdcall threadFun(void *arglist)
{
	para_data* p = static_cast<para_data*>(arglist);
	uint64** genoY_G = p->genoY_G;
	const int* pMarginalDistrSNP_Y = p->pMarginalDistrSNP_Y;
	const int* nlongintY_G = p->nlongintY_G;

	const int start = p->start_snp;
	const int end = p->end_snp;
	const int nsnps = p->nsnps;
	const int ncov = p->numofCov;
	const double thresholdRecord = p->thresholdRecord;
	const int threadIdx = p->index;

	// The scratch buffer has room for 64 tables, i.e. 2^(ncov + 1) with ncov <= 5.
	const int maxTables = 64;
	int localGenoDistr[maxTables * NumOfCell];

	// Reject covariate counts that would overrun localGenoDistr (ncov > 5) or
	// make the shift below undefined (ncov < 0).
	// NOTE(review): as in the original dispatch, every ncov outside 1..4 is
	// routed to postCorrection_5 -- confirm ncov == 0 is never passed in.
	if (ncov < 0 || ncov > 5)
	{
		return 0;
	}
	const int numofTables = 1 << (ncov + 1);

	// [start, end] is inclusive, so start == end is a valid one-SNP slice;
	// only end < start is empty. The last SNP has no partner with a larger
	// index, hence the start >= nsnps - 1 check.
	if (end < start || start >= nsnps - 1 || end >= nsnps)
	{
		return 0;
	}

	// Offsets are computed in size_t so that large inputs cannot overflow int.
	const size_t stride = static_cast<size_t>(nsnps);
	int oldprogress = 0;

	for (int snp1 = start; snp1 <= end; snp1++)
	{
		// Only thread 0 reports progress, measured against its own slice.
		// The value depends on snp1 alone, so it is computed once per outer
		// iteration instead of once per pair.
		if (threadIdx == 0)
		{
			// end == 0 would turn the ratio into 0/0; treat that slice as done.
			const int newprogress = (end > 0) ? (int)(((double)snp1 / (double)end) * 100) : 100;
			if (newprogress != oldprogress)
			{
				printf("\rProgress: %d%%", newprogress);
				oldprogress = newprogress;
			}
		}

		for (int snp2 = snp1 + 1; snp2 < nsnps; snp2++)
		{
			for (int m = 0; m < numofTables; m++)
			{
				const uint64* words = genoY_G[m];
				const int nwords = nlongintY_G[m];
				int* cell = localGenoDistr + m * NumOfCell;

				// Directly counted cells: g1, g2 in {0, 1}.
				for (int i = 0; i < 2; i++)
				{
					for (int j = 0; j < 2; j++)
					{
						int count = 0;
						for (int index = 0; index < nwords; index++)
						{
							// Word for (block index, genotype g, snp) lives at
							// (index * 3 + g) * nsnps + snp.
							const size_t base = static_cast<size_t>(index) * 3;
							const uint64 andResult = words[(base + i) * stride + snp1] & words[(base + j) * stride + snp2];
							count += popcount(andResult);
						}
						cell[i * 3 + j] = count;
					}
				}

				// Marginal counts for genotype g in table m are stored at
				// (g * numofTables + m) * nsnps + snp.
				const int* marginal0 = pMarginalDistrSNP_Y + static_cast<size_t>(m) * stride;
				const int* marginal1 = pMarginalDistrSNP_Y + (static_cast<size_t>(numofTables) + m) * stride;
				const int* marginal2 = pMarginalDistrSNP_Y + (static_cast<size_t>(numofTables) * 2 + m) * stride;

				// Fill the remaining cells by subtracting the known cells from
				// the row (snp1) and column (snp2) marginals. Order matters:
				// cell 8 relies on cells 2 and 5 computed just above it.
				cell[2] = marginal0[snp1] - cell[0] - cell[1];
				cell[5] = marginal1[snp1] - cell[3] - cell[4];
				cell[6] = marginal0[snp2] - cell[0] - cell[3];
				cell[7] = marginal1[snp2] - cell[1] - cell[4];
				cell[8] = marginal2[snp2] - cell[2] - cell[5];
			}

			// Score = postCorrection_N(tables, 1) - postCorrection_N(tables, 0),
			// with N chosen by the number of covariates.
			double interactionMeasure = 0;
			if (ncov == 1)
				interactionMeasure = postCorrection_1(localGenoDistr, 1) - postCorrection_1(localGenoDistr, 0);
			else if (ncov == 2)
				interactionMeasure = postCorrection_2(localGenoDistr, 1) - postCorrection_2(localGenoDistr, 0);
			else if (ncov == 3)
				interactionMeasure = postCorrection_3(localGenoDistr, 1) - postCorrection_3(localGenoDistr, 0);
			else if (ncov == 4)
				interactionMeasure = postCorrection_4(localGenoDistr, 1) - postCorrection_4(localGenoDistr, 0);
			else
				interactionMeasure = postCorrection_5(localGenoDistr, 1) - postCorrection_5(localGenoDistr, 0);

			if (interactionMeasure > thresholdRecord)
			{
				(p->Partial_InterPairs)->push_back(make_pair(snp1, snp2));
				(p->Partial_InterMeasure)->push_back(interactionMeasure);
			}
		}
	}

	return 1;
}


/**
 * Scans all SNP pairs for interactions using numofThread worker threads and
 * appends every pair whose measure exceeds thresholdRecord to the output
 * vectors (results are appended in thread order, existing content is kept).
 *
 * The first-SNP index range [0, nsnps - 1] is cut into one contiguous slice
 * per thread. Slice i ends at nsnps * (1 - sqrt((T - i - 1) / T)); since the
 * number of partners of a SNP shrinks with its index, this presumably aims at
 * giving each thread a similar number of pairs.
 *
 * @param genoY_G             Packed genotype words, forwarded to the workers.
 * @param nsnps               Number of SNPs.
 * @param nsamples            Forwarded to the workers.
 * @param numOfCov            Number of covariates; selects the scoring routine
 *                            in the worker.
 * @param nlongintY_G         Per-table word counts, forwarded to the workers.
 * @param nY_G                Not used by this function.
 * @param pMarginalDistrSNP   Forwarded to the workers.
 * @param pMarginalDistrSNP_Y Per-table marginal counts, forwarded to the workers.
 * @param wordbits            Not used by this function.
 * @param wordBitCount        Not used by this function.
 * @param interactionPairs    Output: (snp1, snp2) index pairs above threshold.
 * @param interactionMeasure  Output: measure for the pair at the same position.
 * @param thresholdRecord     Minimum measure (exclusive) for a pair to be kept.
 * @param numofThread         Requested worker count; values < 1 are treated as 1.
 */
void GetInteractionPairs(uint64** genoY_G, int nsnps, int nsamples, int numOfCov, int* nlongintY_G, int* nY_G, int* pMarginalDistrSNP, int* pMarginalDistrSNP_Y,
	const unsigned char* wordbits, int wordBitCount, vector<pair<int,int>>&interactionPairs, vector<double>&interactionMeasure, double thresholdRecord,int numofThread)
{
	printf("Start calculation interaction...\n");

	// A non-positive thread count would index paras[-1] below; fall back to
	// a single worker instead.
	const int threadCount = (numofThread > 0) ? numofThread : 1;

	// All per-thread state is owned by these containers, so it is released
	// automatically on return. They are sized up front and never resized,
	// which keeps the addresses handed to the workers stable.
	// NOTE(review): assumes para_data does not itself free Partial_InterPairs /
	// Partial_InterMeasure in a destructor -- confirm against utility.h.
	vector<para_data> paras(threadCount);
	vector<vector<pair<int, int> > > partialPairs(threadCount);
	vector<vector<double> > partialMeasures(threadCount);
	vector<unsigned int> threadID(threadCount);
	vector<HANDLE> hth(threadCount);

	int preThreadSNPEnd = -1;
	for (int i = 0; i < threadCount; i++)
	{
		para_data& job = paras[i];

		job.Partial_InterPairs = &partialPairs[i];
		job.Partial_InterMeasure = &partialMeasures[i];

		job.genoY_G = genoY_G;
		job.pMarginalDistrSNP = pMarginalDistrSNP;
		job.pMarginalDistrSNP_Y = pMarginalDistrSNP_Y;
		job.nlongintY_G = nlongintY_G;
		job.nsamples = nsamples;
		job.nsnps = nsnps;
		job.thresholdRecord = thresholdRecord;
		job.numofCov = numOfCov;
		job.index = i;

		// Decide the task dividing point for this thread (fraction of nsnps
		// at which its slice ends). The last slice always ends at the last SNP.
		const double divPoint = 1 - sqrt((double)(threadCount - i - 1) / threadCount);
		job.start_snp = preThreadSNPEnd + 1;
		job.end_snp = (i == threadCount - 1) ? (nsnps - 1) : static_cast<int>(nsnps * divPoint);
		preThreadSNPEnd = job.end_snp;
	}

	for (int i = 0; i < threadCount; i++)
	{
		// _beginthreadex returns 0 on failure; that case is handled below.
		hth[i] = reinterpret_cast<HANDLE>(_beginthreadex(NULL, 0, threadFun, &paras[i], 0, &threadID[i]));
	}

	for (int i = 0; i < threadCount; i++)
	{
		if (hth[i] == NULL)
		{
			// The thread could not be created: process its slice on the
			// calling thread so that no SNP pairs are silently skipped.
			threadFun(&paras[i]);
			continue;
		}
		WaitForSingleObject(hth[i], INFINITE);
		CloseHandle(hth[i]);
	}

	// Append the per-thread results in thread order.
	for (int i = 0; i < threadCount; i++)
	{
		interactionPairs.insert(interactionPairs.end(), partialPairs[i].begin(), partialPairs[i].end());
		interactionMeasure.insert(interactionMeasure.end(), partialMeasures[i].begin(), partialMeasures[i].end());
	}

	return;
}